from model.db import db
import json
import copy

from utils.format_util import py_to_java
import logging

# Module-import side effect: configures the root logger for the whole process.
# NOTE(review): prefer a module-level `logger = logging.getLogger(__name__)` and
# configuring handlers once in the app entry point — confirm before changing.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")


class Task(db.Model):
	"""Persisted pipeline task whose configuration lives in a JSON blob.

	NOTE(review): every column except ``id`` and ``name`` is declared as a
	bare ``db.Column()`` with no SQLAlchemy type — presumably the real table
	schema is created elsewhere (migrations / existing DB). Confirm before
	relying on ``db.create_all()`` with this model.
	"""

	# Primary key.
	id = db.Column(db.Integer, primary_key=True)
	# Human-readable task name (max 100 chars).
	name = db.Column(db.String(100))
	# Owning project; type not declared here — see class docstring.
	project_id = db.Column()
	pipeline_id = db.Column()
	# Upstream task this one consumes output from (used by update_task).
	parent_id = db.Column()
	user_id = db.Column()
	# Task kind; shadows the `type` builtin but is part of the DB interface.
	type = db.Column()
	# Serialized task configuration; read/written as a JSON string.
	data_json = db.Column()

	def __init__(self, name, project_id, type, data_json):
		self.name = name
		self.project_id = project_id
		self.type = type
		self.data_json = data_json

	def update_data_json(self, data_json):
		# Replace the serialized configuration; caller is responsible for commit.
		self.data_json = data_json


def update_task(taskId, df, target, update, results=None):
	"""Merge a prediction dataframe's schema into a task's JSON config.

	Copies the parent task's ``output`` section, rewrites its column list,
	semantic keys and table name to match ``df``, merges ``results`` into the
	task's top-level JSON, then persists the result via ``py_to_java``.

	Args:
		taskId: primary key of the task to update.
		df: pandas DataFrame whose columns define the new output schema.
		target: new value for the output's ``tableName``.
		update: when truthy, register columns of ``df`` that the parent
			output lacked (semantic "null", column type "REAL").
		results: optional mapping merged into the task's top-level JSON.

	Returns:
		1 on success, 0 when the parent task has no output tables.
	"""
	if results is None:  # avoid the shared-mutable-default pitfall
		results = {}
	try:
		task = Task.query.filter_by(id=taskId).first()
		parent_task = Task.query.filter_by(id=task.parent_id).first()
	except Exception:
		# A failed/aborted transaction leaves the session unusable; roll it
		# back once and retry the lookups.
		db.session.rollback()
		logging.info("session rollback")
		task = Task.query.filter_by(id=taskId).first()
		parent_task = Task.query.filter_by(id=task.parent_id).first()

	data_json = json.loads(task.data_json)
	parent_data_json = json.loads(parent_task.data_json)
	logging.info("PREVIOUS ORIGINAL DATA JSON__________")

	# Work on a private copy of the parent's output so the parent task's
	# JSON is never mutated through shared references.
	data_json_output = copy.deepcopy(parent_data_json["output"])
	if not data_json_output:
		return 0

	columns_df_pred = df.columns.to_list()
	# Columns the dataframe has that the parent output did not.
	old_columns = data_json_output[0]["tableCols"]
	new_columns = [col for col in columns_df_pred if col not in old_columns]
	# The dataframe's columns become the authoritative column list.
	data_json_output[0]["tableCols"] = columns_df_pred

	# Normalize semantic keys so they line up with dataframe column names
	# (spaces, slashes and dashes become underscores).
	data_json_output[0]["semantic"] = {
		key.replace(" ", "_").replace("/", "_").replace("-", "_"): value
		for key, value in data_json_output[0]["semantic"].items()
	}
	data_json_output[0]["tableName"] = target

	# Merge caller-supplied results into the task's top-level JSON.
	for key in results:
		data_json[key] = results[key]

	if update:
		# Register brand-new columns: no semantic, numeric ("REAL") type
		# inserted at the column's position to keep the two lists aligned.
		for new_column in new_columns:
			data_json_output[0]["semantic"][new_column] = "null"
			index = data_json_output[0]["tableCols"].index(new_column)
			data_json_output[0]["columnTypes"].insert(index, "REAL")

	data_json["output"] = data_json_output
	# py_to_java expects the Python-repr string form of the dict.
	task.update_data_json(py_to_java(str(data_json)))
	db.session.commit()
	return 1


def get_col_type(taskId):
	"""Return a mapping of input-table column name -> declared column type.

	Reads the task's ``data_json`` and zips ``input[0].tableCols`` with
	``input[0].columnTypes``, skipping the internal ``_record_id_`` column.
	"""
	task = Task.query.filter_by(id=taskId).first()
	table = json.loads(task.data_json)["input"][0]
	return {
		col: col_type
		for col, col_type in zip(table["tableCols"], table["columnTypes"])
		if col != "_record_id_"
	}


def filter_numeric_col(taskId):
	"""Partition a task's input columns into numeric and non-numeric names.

	Returns:
		(numeric_cols, non_numeric_cols) — two lists of column names, in the
		order they appear in the task's input schema.
	"""
	numeric_types = {"BIGINT", "INTEGER", "DECIMAL", "FLOAT",
	                 "SMALLINT", "NUMERIC", "DOUBLE", "REAL"}
	col_type = get_col_type(taskId)
	numeric_cols = [col for col, t in col_type.items() if t in numeric_types]
	non_numeric_cols = [col for col, t in col_type.items() if t not in numeric_types]
	return numeric_cols, non_numeric_cols



